package nvd

import (
	"encoding/json"
	"fmt"
	"runtime"
	"strings"
	"time"

	"net/http"
	"net/url"

	"cve-db/models"
	"cve-db/fetcher"
	"cve-db/util"
	"cve-db/config"
	"strconv"
	"github.com/hashicorp/go-version"
	"cve-db/db"
	log "github.com/sirupsen/logrus"
	"github.com/k0kubun/pp"
	"github.com/PuerkitoBio/goquery"
)

// FetchConvert downloads the gzip-compressed NVD JSON feeds referenced by
// metas and converts every CVE item they contain into models.CveDetail.
//
// A failure to download the feeds, or to JSON-decode any single feed, aborts
// immediately. Conversion errors are collected per feed; if any occurred,
// FetchConvert returns nil together with an error listing all of them.
func FetchConvert(metas []models.FeedMeta) (cves []models.CveDetail, err error) {
	reqs := make([]fetcher.FetchRequest, 0, len(metas))
	for _, m := range metas {
		reqs = append(reqs, fetcher.FetchRequest{URL: m.URL, GZIP: true})
	}

	feeds, fetchErr := fetcher.FetchFeedDatas(reqs)
	if fetchErr != nil {
		return nil, fmt.Errorf("Failed to fetch. err: %s", fetchErr)
	}

	var convErrs []error
	for _, feed := range feeds {
		var doc NvdJSON
		if decodeErr := json.Unmarshal(feed.Body, &doc); decodeErr != nil {
			return nil, fmt.Errorf("Failed to unmarshal. url: %s, err: %s", feed.URL, decodeErr)
		}
		converted, convErr := convert(doc.CveItems)
		if convErr != nil {
			convErrs = append(convErrs, convErr)
		}
		cves = append(cves, converted...)
	}

	if len(convErrs) > 0 {
		return nil, fmt.Errorf("%s", convErrs)
	}
	return cves, nil
}

// convert turns the items of one NVD feed into CveDetail models by running
// convertToModel on a pool of 10 workers obtained from util.GenWorkers.
//
// It returns nil and an error listing every per-item failure if any item failed
// to convert. It returns nil and a "Timeout Fetching" error if not all results
// arrive within 10 minutes after the last task was submitted. Results are
// appended in completion order, not input order.
//
// NOTE(review): the tasks channel is never closed here; whether the goroutines
// started by util.GenWorkers exit afterwards depends on its implementation —
// confirm.
func convert(items []CveItem) (cves []models.CveDetail, err error) {
	if len(items) == 0 {
		// Nothing to convert; avoid starting a worker pool.
		return nil, nil
	}

	// Buffered to len(items) so every task can deliver its single result
	// without blocking, even after this function has returned on timeout.
	// The channels are intentionally never closed: a task finishing after a
	// timeout would otherwise panic with "send on closed channel". Unreferenced
	// channels are garbage-collected.
	resChan := make(chan *models.CveDetail, len(items))
	errChan := make(chan error, len(items))

	concurrency := 10
	tasks := util.GenWorkers(concurrency)
	for i := range items {
		item := &items[i] // distinct variable per iteration, safe to capture
		tasks <- func() {
			cve, err := convertToModel(item)
			if err != nil {
				errChan <- err
				return
			}
			resChan <- cve
		}
	}

	errs := []error{}
	timeout := time.After(10 * 60 * time.Second)
	for range items {
		select {
		case res := <-resChan:
			cves = append(cves, *res)
		case err := <-errChan:
			errs = append(errs, err)
		case <-timeout:
			return nil, fmt.Errorf("Timeout Fetching")
		}
	}
	if 0 < len(errs) {
		return nil, fmt.Errorf("%s", errs)
	}
	return cves, nil
}

// NvdJSON is the top-level document of an NVD JSON CVE feed; FetchConvert
// decodes each downloaded feed body into it and converts CveItems.
//
// Schema: https://scap.nist.gov/schema/nvd/feed/0.1/nvd_cve_feed_json_0.1_beta.schema
// NOTE(review): that link is the 0.1 beta schema, while the feeds requested in
// this package are nvdcve-1.0 — confirm it is still the intended reference.
// The header fields, including the CVE count and timestamp, are decoded as
// raw strings.
type NvdJSON struct {
	CveDataType         string    `json:"CVE_data_type"`
	CveDataFormat       string    `json:"CVE_data_format"`
	CveDataVersion      string    `json:"CVE_data_version"`
	CveDataNumberOfCVEs string    `json:"CVE_data_numberOfCVEs"`
	CveDataTimestamp    string    `json:"CVE_data_timestamp"`
	CveItems            []CveItem `json:"CVE_Items"`
}

// CveItem is one element of NvdJSON.CveItems: a single CVE entry of an NVD
// JSON feed.
//
// Sections (JSON key in parentheses):
//   - Cve (cve): the CVE identifier (CVE_data_meta.ID), a vendor/product/version
//     list (affects), CWE identifiers (problemtype), reference URLs with their
//     tags (references) and per-language descriptions (description).
//   - Configurations (configurations): applicability nodes. Each node has an
//     operator, a negate flag, cpe_match entries (CPE 2.3 URI, vulnerable flag
//     and optional version-range bounds) and optional child nodes. Child nodes
//     decode only an operator and cpe_match entries; no negate field is decoded
//     for them.
//   - Impact (impact): CVSS v3 and v2 base metrics. Both are plain (non-pointer)
//     structs, so a CVE lacking one of them decodes to zero values, not nil.
//   - PublishedDate / LastModifiedDate: kept as raw strings in this struct.
type CveItem struct {
	Cve struct {
		DataType    string `json:"data_type"`
		DataFormat  string `json:"data_format"`
		DataVersion string `json:"data_version"`
		CveDataMeta struct {
			ID       string `json:"ID"`
			ASSIGNER string `json:"ASSIGNER"`
		} `json:"CVE_data_meta"`
		Affects struct {
			Vendor struct {
				VendorData []struct {
					VendorName string `json:"vendor_name"`
					Product    struct {
						ProductData []struct {
							ProductName string `json:"product_name"`
							Version     struct {
								VersionData []struct {
									VersionValue string `json:"version_value"`
								} `json:"version_data"`
							} `json:"version"`
						} `json:"product_data"`
					} `json:"product"`
				} `json:"vendor_data"`
			} `json:"vendor"`
		} `json:"affects"`
		Problemtype struct {
			ProblemtypeData []struct {
				Description []struct {
					Lang  string `json:"lang"`
					Value string `json:"value"`
				} `json:"description"`
			} `json:"problemtype_data"`
		} `json:"problemtype"`
		References struct {
			ReferenceData []struct {
				URL  string   `json:"url"`
				TAGS []string `json:"tags"`
			} `json:"reference_data"`
		} `json:"references"`
		Description struct {
			DescriptionData []struct {
				Lang  string `json:"lang"`
				Value string `json:"value"`
			} `json:"description_data"`
		} `json:"description"`
	} `json:"cve"`
	Configurations struct {
		CveDataVersion string `json:"CVE_data_version"`
		Nodes          []struct {
			Operator string `json:"operator"`
			Negate   bool   `json:"negate"`
			Cpes     []struct {
				Vulnerable            bool   `json:"vulnerable"`
				Cpe23URI              string `json:"cpe23Uri"`
				VersionStartExcluding string `json:"versionStartExcluding"`
				VersionStartIncluding string `json:"versionStartIncluding"`
				VersionEndExcluding   string `json:"versionEndExcluding"`
				VersionEndIncluding   string `json:"versionEndIncluding"`
			} `json:"cpe_match"`
			Children []struct {
				Operator string `json:"operator"`
				Cpes     []struct {
					Vulnerable            bool   `json:"vulnerable"`
					Cpe23URI              string `json:"cpe23Uri"`
					VersionStartExcluding string `json:"versionStartExcluding"`
					VersionStartIncluding string `json:"versionStartIncluding"`
					VersionEndExcluding   string `json:"versionEndExcluding"`
					VersionEndIncluding   string `json:"versionEndIncluding"`
				} `json:"cpe_match"`
			} `json:"children,omitempty"`
		} `json:"nodes"`
	} `json:"configurations"`
	Impact struct {
		BaseMetricV3 struct {
			CvssV3 struct {
				Version               string  `json:"version"`
				VectorString          string  `json:"vectorString"`
				AttackVector          string  `json:"attackVector"`
				AttackComplexity      string  `json:"attackComplexity"`
				PrivilegesRequired    string  `json:"privilegesRequired"`
				UserInteraction       string  `json:"userInteraction"`
				Scope                 string  `json:"scope"`
				ConfidentialityImpact string  `json:"confidentialityImpact"`
				IntegrityImpact       string  `json:"integrityImpact"`
				AvailabilityImpact    string  `json:"availabilityImpact"`
				BaseScore             float64 `json:"baseScore"`
				BaseSeverity          string  `json:"baseSeverity"`
			} `json:"cvssV3"`
			ExploitabilityScore float64 `json:"exploitabilityScore"`
			ImpactScore         float64 `json:"impactScore"`
		} `json:"baseMetricV3"`
		BaseMetricV2 struct {
			CvssV2 struct {
				Version               string  `json:"version"`
				VectorString          string  `json:"vectorString"`
				AccessVector          string  `json:"accessVector"`
				AccessComplexity      string  `json:"accessComplexity"`
				Authentication        string  `json:"authentication"`
				ConfidentialityImpact string  `json:"confidentialityImpact"`
				IntegrityImpact       string  `json:"integrityImpact"`
				AvailabilityImpact    string  `json:"availabilityImpact"`
				BaseScore             float64 `json:"baseScore"`
			} `json:"cvssV2"`
			Severity                string  `json:"severity"`
			ExploitabilityScore     float64 `json:"exploitabilityScore"`
			ImpactScore             float64 `json:"impactScore"`
			ObtainAllPrivilege      bool    `json:"obtainAllPrivilege"`
			ObtainUserPrivilege     bool    `json:"obtainUserPrivilege"`
			ObtainOtherPrivilege    bool    `json:"obtainOtherPrivilege"`
			UserInteractionRequired bool    `json:"userInteractionRequired"`
		} `json:"baseMetricV2"`
	} `json:"impact"`
	PublishedDate    string `json:"publishedDate"`
	LastModifiedDate string `json:"lastModifiedDate"`
}

// CertLink temporarily holds a reference URL that convertToModel selected as a
// cert/advisory candidate, before collectCertLinks fetches its page title.
type CertLink struct {
	Link string
}

// convertToModel converts one NVD JSON feed item into a models.CveDetail.
//
// Mapping overview:
//   - every reference URL becomes a models.Reference;
//   - reference URLs containing "ncas/alerts" or "cas/techalerts", or whose
//     tags contain "US Government Resource", are fetched over HTTP by
//     collectCertLinks to build models.Cert entries. URLs not starting with
//     "http" and ".pdf" URLs are skipped;
//   - each problemtype description value becomes a models.Cwe;
//   - each vendor/product/version triple becomes a models.Affect;
//   - configuration nodes become models.Cpe values (see the CPE loop below);
//   - published/last-modified dates are parsed with parseNvdJSONTime.
//
// It returns an error when collecting cert links fails, a child CPE URI cannot
// be parsed, a CPE version bound is rejected by checkIfVersionParsable, or a
// date cannot be parsed. Unparsable top-level CPE URIs are only logged.
func convertToModel(item *CveItem) (*models.CveDetail, error) {
	cveID := item.Cve.CveDataMeta.ID

	// parseCpe parses a CPE 2.3 URI and copies the NVD version-range bounds
	// onto the result.
	parseCpe := func(uri, startExc, startInc, endExc, endInc string) (*models.CpeBase, error) {
		cpeBase, err := fetcher.ParseCpeURI(uri)
		if err != nil {
			return nil, err
		}
		cpeBase.VersionStartExcluding = startExc
		cpeBase.VersionStartIncluding = startInc
		cpeBase.VersionEndExcluding = endExc
		cpeBase.VersionEndIncluding = endInc
		return cpeBase, nil
	}
	// versionParseErr is returned when checkIfVersionParsable rejects a
	// vulnerable CPE of this item; the dump of *item is meant as the body of a
	// bug report.
	versionParseErr := func() error {
		return fmt.Errorf(
			"Version parse err. Please add a issue on [GitHub](https://github.com/kotakanbe/go-cve-dictionary/issues/new). Title: %s, Content:%s",
			cveID,
			pp.Sprintf("%v", *item),
		)
	}

	// References, plus the subset that are cert/advisory candidates whose page
	// titles are fetched below.
	refs := make([]models.Reference, 0, len(item.Cve.References.ReferenceData))
	links := []CertLink{}
	for _, ref := range item.Cve.References.ReferenceData {
		refs = append(refs, models.Reference{Link: ref.URL})

		tag := strings.Join(ref.TAGS, " ")
		isCertCandidate := strings.Contains(ref.URL, "ncas/alerts") ||
			strings.Contains(ref.URL, "cas/techalerts") ||
			strings.Contains(tag, "US Government Resource")
		// Only HTTP(S) pages can be fetched and scraped for a title; PDFs are skipped.
		if isCertCandidate && strings.HasPrefix(ref.URL, "http") && !strings.HasSuffix(ref.URL, ".pdf") {
			links = append(links, CertLink{Link: ref.URL})
		}
	}

	certs, err := collectCertLinks(links)
	if err != nil {
		return nil,
			fmt.Errorf("Failed to collect links. err: %s", err)
	}

	// Cwes
	cwes := []models.Cwe{}
	for _, data := range item.Cve.Problemtype.ProblemtypeData {
		for _, desc := range data.Description {
			cwes = append(cwes, models.Cwe{
				CweID: desc.Value,
			})
		}
	}

	// Affects
	affects := []models.Affect{}
	for _, vendor := range item.Cve.Affects.Vendor.VendorData {
		for _, prod := range vendor.Product.ProductData {
			for _, ver := range prod.Version.VersionData {
				affects = append(affects, models.Affect{
					Vendor:  vendor.VendorName,
					Product: prod.ProductName,
					Version: ver.VersionValue,
				})
			}
		}
	}

	// CPE loop. Negated nodes are skipped. For every other node:
	//   - each top-level cpe_match becomes a Cpe;
	//   - each vulnerable child cpe_match becomes a Cpe;
	//   - each non-vulnerable child cpe_match of an "AND" node is attached as an
	//     EnvCpe to every Cpe collected for that node so far.
	cpes := []models.Cpe{}
	for _, node := range item.Configurations.Nodes {
		if node.Negate {
			continue
		}

		nodeCpes := []models.Cpe{}
		for _, cpe := range node.Cpes {
			// The vulnerable flag is deliberately ignored at this level:
			// CVE-2017-14492 and CVE-2017-8581 have cpes marked
			// vulnerable:false that are in fact vulnerable.
			cpeBase, err := parseCpe(cpe.Cpe23URI,
				cpe.VersionStartExcluding, cpe.VersionStartIncluding,
				cpe.VersionEndExcluding, cpe.VersionEndIncluding)
			if err != nil {
				// logging only
				log.Infof("Failed to parse CpeURI %s: %s", cpe.Cpe23URI, err)
				continue
			}
			if !checkIfVersionParsable(cpeBase) {
				return nil, versionParseErr()
			}
			nodeCpes = append(nodeCpes, models.Cpe{CpeBase: *cpeBase})
		}

		for _, child := range node.Children {
			for _, cpe := range child.Cpes {
				if cpe.Vulnerable {
					cpeBase, err := parseCpe(cpe.Cpe23URI,
						cpe.VersionStartExcluding, cpe.VersionStartIncluding,
						cpe.VersionEndExcluding, cpe.VersionEndIncluding)
					if err != nil {
						return nil, err
					}
					if !checkIfVersionParsable(cpeBase) {
						return nil, versionParseErr()
					}
					nodeCpes = append(nodeCpes, models.Cpe{CpeBase: *cpeBase})
					continue
				}

				// Non-vulnerable child: only relevant as an environment CPE of
				// an AND node that already has cpes to attach it to.
				if node.Operator != "AND" || len(nodeCpes) == 0 {
					continue
				}
				// Parsed and validated once, then copied onto every cpe.
				envBase, err := parseCpe(cpe.Cpe23URI,
					cpe.VersionStartExcluding, cpe.VersionStartIncluding,
					cpe.VersionEndExcluding, cpe.VersionEndIncluding)
				if err != nil {
					return nil, err
				}
				if !checkIfVersionParsable(envBase) {
					return nil, fmt.Errorf(
						"Please add a issue on [GitHub](https://github.com/kotakanbe/go-cve-dictionary/issues/new). Title: Version parse err: %s, Content:%s",
						cveID,
						pp.Sprintf("%v", *item),
					)
				}
				for i := range nodeCpes {
					nodeCpes[i].EnvCpes = append(nodeCpes[i].EnvCpes, models.EnvCpe{
						CpeBase: *envBase,
					})
				}
			}
		}
		cpes = append(cpes, nodeCpes...)
	}

	// Description
	descs := []models.Description{}
	for _, desc := range item.Cve.Description.DescriptionData {
		descs = append(descs, models.Description{
			Lang:  desc.Lang,
			Value: desc.Value,
		})
	}

	publish, err := parseNvdJSONTime(item.PublishedDate)
	if err != nil {
		return nil, err
	}
	modified, err := parseNvdJSONTime(item.LastModifiedDate)
	if err != nil {
		return nil, err
	}
	c2 := item.Impact.BaseMetricV2
	c3 := item.Impact.BaseMetricV3

	return &models.CveDetail{
		CveID: cveID,
		NvdJSON: &models.NvdJSON{
			CveID:        cveID,
			Descriptions: descs,
			Cvss2: models.Cvss2Extra{
				Cvss2: models.Cvss2{
					VectorString:          c2.CvssV2.VectorString,
					AccessVector:          c2.CvssV2.AccessVector,
					AccessComplexity:      c2.CvssV2.AccessComplexity,
					Authentication:        c2.CvssV2.Authentication,
					ConfidentialityImpact: c2.CvssV2.ConfidentialityImpact,
					IntegrityImpact:       c2.CvssV2.IntegrityImpact,
					AvailabilityImpact:    c2.CvssV2.AvailabilityImpact,
					BaseScore:             c2.CvssV2.BaseScore,
					Severity:              c2.Severity,
				},
				ExploitabilityScore:     c2.ExploitabilityScore,
				ImpactScore:             c2.ImpactScore,
				ObtainAllPrivilege:      c2.ObtainAllPrivilege,
				ObtainUserPrivilege:     c2.ObtainUserPrivilege,
				ObtainOtherPrivilege:    c2.ObtainOtherPrivilege,
				UserInteractionRequired: c2.UserInteractionRequired,
			},
			Cvss3: models.Cvss3{
				VectorString:          c3.CvssV3.VectorString,
				AttackVector:          c3.CvssV3.AttackVector,
				AttackComplexity:      c3.CvssV3.AttackComplexity,
				PrivilegesRequired:    c3.CvssV3.PrivilegesRequired,
				UserInteraction:       c3.CvssV3.UserInteraction,
				Scope:                 c3.CvssV3.Scope,
				ConfidentialityImpact: c3.CvssV3.ConfidentialityImpact,
				IntegrityImpact:       c3.CvssV3.IntegrityImpact,
				AvailabilityImpact:    c3.CvssV3.AvailabilityImpact,
				BaseScore:             c3.CvssV3.BaseScore,
				BaseSeverity:          c3.CvssV3.BaseSeverity,
				ExploitabilityScore:   c3.ExploitabilityScore,
				ImpactScore:           c3.ImpactScore,
			},
			Cwes:             cwes,
			Cpes:             cpes,
			References:       refs,
			Affects:          affects,
			Certs:            certs,
			PublishedDate:    publish,
			LastModifiedDate: modified,
		},
	}, nil
}

// collectCertLinks fetches every link concurrently, using runtime.NumCPU()
// workers from util.GenWorkers, and returns one models.Cert per successfully
// fetched page.
//
// The cert title is the <title> text for kb.cert.org and kaspersky.com pages,
// and the text of the element with id "page-sub-title" otherwise. A link is
// logged at debug level and skipped when any of these happen:
//   - the request fails or exceeds the 30s per-request timeout;
//   - the response status is not 2xx;
//   - the body cannot be loaded into a goquery document.
//
// Waiting for results stops after 5 minutes overall, and the certs gathered so
// far are returned. When config.Conf.HTTPProxy is set, requests go through
// that proxy. The only error returned is for an unparsable proxy URL.
func collectCertLinks(links []CertLink) (certs []models.Cert, err error) {
	// Bound each request so a stalled server cannot occupy a worker forever.
	httpClient := &http.Client{Timeout: 30 * time.Second}
	if config.Conf.HTTPProxy != "" {
		var proxyURL *url.URL
		if proxyURL, err = url.Parse(config.Conf.HTTPProxy); err != nil {
			return nil, fmt.Errorf("failed to parse proxy url: %s", err)
		}
		httpClient.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
	}

	if len(links) == 0 {
		// Nothing to fetch; avoid starting a worker pool.
		return nil, nil
	}

	// fetchTitle downloads link and extracts the title used for its cert.
	fetchTitle := func(link string) (string, error) {
		req, err := http.NewRequest("GET", link, nil)
		if err != nil {
			return "", err
		}
		res, err := httpClient.Do(req)
		if err != nil {
			return "", err
		}
		defer res.Body.Close()
		// Error pages (404, 5xx, ...) would otherwise be stored as cert titles.
		if res.StatusCode < 200 || 299 < res.StatusCode {
			return "", fmt.Errorf("unexpected HTTP status: %s", res.Status)
		}
		doc, err := goquery.NewDocumentFromReader(res.Body)
		if err != nil {
			return "", err
		}
		if strings.Contains(link, "kb.cert.org") || strings.Contains(link, "kaspersky.com") {
			return doc.Find("title").Text(), nil
		}
		return doc.Find("#page-sub-title").Text(), nil
	}

	// Buffered to len(links) so every task can deliver its single result
	// without blocking, even after the collector below stops reading. The
	// channels are intentionally not closed: a task finishing late would panic
	// with "send on closed channel".
	resChan := make(chan models.Cert, len(links))
	errChan := make(chan error, len(links))

	tasks := util.GenWorkers(runtime.NumCPU())
	for _, l := range links {
		// Per-iteration copy: the task runs later on another goroutine, so it
		// must not read the shared range variable.
		link := l.Link
		tasks <- func() {
			log.Debugf("Fetching %s", link)
			title, err := fetchTitle(link)
			if err != nil {
				log.Debugf("Failed to get %s: err: %s", link, err)
				errChan <- err
				return
			}
			resChan <- models.Cert{
				Title: title,
				Link:  link,
			}
		}
	}

	// time.After fires only once, so stop collecting as soon as it does instead
	// of blocking on tasks that may never finish.
	timeout := time.After(10 * 30 * time.Second)
collect:
	for range links {
		select {
		case cert := <-resChan:
			certs = append(certs, cert)
		case <-errChan:
			// Already logged by the task; the link is simply skipped.
		case <-timeout:
			log.Debugf("timeout fetching")
			break collect
		}
	}
	return certs, nil
}

// checkIfVersionParsable reports whether every non-empty version-range bound
// of cpeBase can be parsed by hashicorp/go-version. This covers the start and
// end bounds, both including and excluding. Backslashes are removed before
// parsing; presumably they are CPE escape characters.
// A CPE whose Version is "ANY" or "NA" is accepted without checking its bounds.
func checkIfVersionParsable(cpeBase *models.CpeBase) bool {
	switch cpeBase.Version {
	case "ANY", "NA":
		return true
	}

	bounds := [...]string{
		cpeBase.VersionStartExcluding,
		cpeBase.VersionStartIncluding,
		cpeBase.VersionEndIncluding,
		cpeBase.VersionEndExcluding,
	}
	for _, bound := range bounds {
		if bound == "" {
			continue
		}
		if _, err := version.NewVersion(strings.Replace(bound, `\`, "", -1)); err != nil {
			return false
		}
	}
	return true
}

// parseNvdJSONTime parses an NVD JSON feed timestamp such as
// "2018-01-02T15:04Z" (minute precision, literal trailing "Z"). The layout has
// no zone directive, so the result is in UTC. On failure, the returned error
// quotes the input string.
func parseNvdJSONTime(strtime string) (t time.Time, err error) {
	const nvdTimeLayout = "2006-01-02T15:04Z"

	parsed, parseErr := time.Parse(nvdTimeLayout, strtime)
	if parseErr == nil {
		return parsed, nil
	}
	return parsed, fmt.Errorf("Failed to parse time, time: %s, err: %s",
		strtime, parseErr)
}

// FetchLatestFeedMeta downloads the NVD ".meta" file of every feed selected by
// years (see MakeNvdMetaURLs). It returns one models.FeedMeta per feed: the
// record returned by driver.GetFetchedFeedMeta for the feed's ".json.gz" URL,
// with URL, LatestHash and LatestLastModifiedDate filled in from the meta file.
//
// A meta file must split into exactly six lines (CRLF or LF endings). Its first
// line is read as "lastModifiedDate:<value>" and its fifth line is stored
// verbatim as the hash. Files with any other shape are logged and skipped.
// A fetch or database error aborts the whole call.
func FetchLatestFeedMeta(driver db.DB, years []int) (metas []models.FeedMeta, err error) {
	reqs := []fetcher.FetchRequest{}
	for _, year := range years {
		for _, metaURL := range MakeNvdMetaURLs(year) {
			reqs = append(reqs, fetcher.FetchRequest{
				Year: year,
				URL:  metaURL,
			})
		}
	}
	results, err := fetcher.FetchFeedDatas(reqs)
	if err != nil {
		log.Errorf("Failed to fetch. err: %s", err)
		return nil, err
	}

	for _, res := range results {
		// Normalize line endings so the file parses whether it is served with
		// CRLF or plain LF; for CRLF input the resulting lines are identical to
		// splitting on "\r\n".
		body := strings.Replace(string(res.Body), "\r\n", "\n", -1)
		lines := strings.Split(body, "\n")
		if len(lines) != 6 {
			// Previously skipped silently, which hid format changes and left
			// the database stale without any diagnostic.
			log.Warnf("Skipping meta with unexpected format. url: %s, lines: %d", res.URL, len(lines))
			continue
		}
		hash := lines[4]

		// The data feed sits next to its .meta file with a .json.gz extension.
		feedURL := strings.Replace(res.URL, ".meta", ".json.gz", -1)

		meta, err := driver.GetFetchedFeedMeta(feedURL)
		if err != nil {
			log.Errorf("Failed to get meta: %d, err: %s", res.Year, err)
			return nil, err
		}
		meta.URL = feedURL
		meta.LatestHash = hash
		meta.LatestLastModifiedDate = strings.TrimPrefix(lines[0], "lastModifiedDate:")
		metas = append(metas, *meta)
	}
	return metas, nil
}

// MakeNvdMetaURLs returns the URLs of the NVD JSON 1.0 feed ".meta" files for
// year. When year equals config.Latest it returns the "modified" and "recent"
// feed URLs (in that order) instead of a yearly one.
// See https://nvd.nist.gov/vuln/data-feeds#JSON_FEED
func MakeNvdMetaURLs(year int) (urls []string) {
	const metaURLFormat = "https://nvd.nist.gov/feeds/json/cve/1.0/nvdcve-1.0-%s.meta"

	if year != config.Latest {
		return append(urls, fmt.Sprintf(metaURLFormat, strconv.Itoa(year)))
	}
	return append(urls,
		fmt.Sprintf(metaURLFormat, "modified"),
		fmt.Sprintf(metaURLFormat, "recent"),
	)
}

// UpdateMeta promotes each meta's LatestHash and LatestLastModifiedDate to
// Hash and LastModifiedDate and upserts the result through driver. It stops at
// the first failure and returns an error naming that feed's URL. The elements
// of metas are updated as copies, so the caller's slice is left unchanged.
func UpdateMeta(driver db.DB, metas []models.FeedMeta) error {
	for i := range metas {
		m := metas[i]
		m.Hash, m.LastModifiedDate = m.LatestHash, m.LatestLastModifiedDate
		if err := driver.UpsertFeedHash(m); err != nil {
			return fmt.Errorf("Failed to updte meta: %s, err: %s",
				m.URL, err)
		}
	}
	return nil
}